#include "quality_evaluation.hpp"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <fstream>
/**
 * Brenner gradient sharpness metric.
 *
 * Sums the squared difference between each pixel and the pixel two columns
 * to its right, f(r, c + 2) - f(r, c), and normalises by the pixel count.
 *
 * Inputs:
 * @param image: BGR (3-channel) or single-channel image; the grayscale data
 *               is read as 8-bit (uchar) values.
 * Return: double  mean squared horizontal 2-pixel difference (higher = sharper)
 */
double brenner(cv::Mat &image)
{
	assert(!image.empty());

	cv::Mat gray_img;
	if (image.channels() == 3){
		cv::cvtColor(image, gray_img, cv::COLOR_BGR2GRAY);
	} else {
		// Already single-channel: use it directly instead of leaving gray_img empty.
		gray_img = image;
	}

	double result = 0.0;
	for (int i = 0; i < gray_img.rows; ++i){
		const uchar *row = gray_img.ptr<uchar>(i);
		for (int j = 0; j + 2 < gray_img.cols; ++j){
			const int d = row[j + 2] - row[j];
			result += static_cast<double>(d) * d;
		}
	}

	return result / static_cast<double>(gray_img.total());
}

/**
 * Tenengrad gradient sharpness metric.
 *
 * Computes horizontal and vertical Sobel gradients and returns the mean
 * gradient magnitude sqrt(gx^2 + gy^2) over the image.
 *
 * Inputs:
 * @param image: BGR (3-channel) or single-channel image
 * Return: double  mean Sobel gradient magnitude (higher = sharper)
 */
double tenengard(cv::Mat &image)
{
	assert(!image.empty());

	cv::Mat gray_img;
	if (image.channels() == 3){
		cv::cvtColor(image, gray_img, cv::COLOR_BGR2GRAY);
	} else {
		// Already single-channel: use it directly instead of leaving gray_img empty.
		gray_img = image;
	}

	// x / y gradients in float precision so negative values are kept.
	cv::Mat sobel_x, sobel_y;
	cv::Sobel(gray_img, sobel_x, CV_32FC1, 1, 0);
	cv::Sobel(gray_img, sobel_y, CV_32FC1, 0, 1);

	// Gradient magnitude sqrt(gx^2 + gy^2).
	cv::multiply(sobel_x, sobel_x, sobel_x);
	cv::multiply(sobel_y, sobel_y, sobel_y);
	cv::Mat sum_sq = sobel_x + sobel_y;
	cv::Mat G;
	cv::sqrt(sum_sq, G);

	return cv::mean(G)[0];
}

/**
 * Laplacian sharpness metric.
 *
 * Applies the Laplacian operator (float output) and returns the mean of its
 * absolute response.
 *
 * Inputs:
 * @param image: BGR (3-channel) or single-channel image
 * Return: double  mean |Laplacian| (higher = sharper)
 */
double laplacian(cv::Mat &image)
{
	assert(!image.empty());

	cv::Mat gray_img;
	if (image.channels() == 3){
		cv::cvtColor(image, gray_img, cv::COLOR_BGR2GRAY);
	} else {
		// Already single-channel: use it directly instead of leaving gray_img empty.
		gray_img = image;
	}

	cv::Mat lap_image;
	cv::Laplacian(gray_img, lap_image, CV_32FC1);
	lap_image = cv::abs(lap_image);

	return cv::mean(lap_image)[0];
}

/**
 * SMD (gray-level variance / sum-of-modulus-of-difference) sharpness metric.
 *
 * With f(r, c) the gray image, computes
 *   dx = f(r, c) - f(r, c + 1)   and   dy = f(r, c) - f(r - 1, c)
 * and returns mean(|dx| + |dy|).
 *
 * Inputs:
 * @param image: BGR (3-channel) or single-channel image
 * Return: double  mean absolute first-order difference (higher = sharper)
 */
double smd(cv::Mat &image)
{
	assert(!image.empty());

	cv::Mat gray_img;
	if (image.channels() == 3){
		cv::cvtColor(image, gray_img, cv::COLOR_BGR2GRAY);
	} else {
		// Already single-channel: use it directly instead of leaving gray_img empty.
		gray_img = image;
	}

	// filter2D computes a correlation, so these kernels give
	// dx = f(r,c) - f(r,c+1) and dy = f(r,c) - f(r-1,c).
	cv::Mat kernel_x(3, 3, CV_32F, cv::Scalar(0));
	kernel_x.at<float>(1, 1) = 1.0f;
	kernel_x.at<float>(1, 2) = -1.0f;
	cv::Mat kernel_y(3, 3, CV_32F, cv::Scalar(0));
	kernel_y.at<float>(1, 1) = 1.0f;
	kernel_y.at<float>(0, 1) = -1.0f;

	// Filter into CV_32F: an 8-bit destination would clamp every negative
	// difference to 0 and make the abs() below meaningless.
	cv::Mat dx, dy;
	cv::filter2D(gray_img, dx, CV_32F, kernel_x);
	cv::filter2D(gray_img, dy, CV_32F, kernel_y);

	dx = cv::abs(dx);
	dy = cv::abs(dy);
	cv::Mat G = dx + dy;

	return cv::mean(G)[0];
}

/**
 * SMD2 (gray-level variance product) sharpness metric.
 *
 * With f(r, c) the gray image, computes
 *   dx = f(r, c) - f(r, c + 1)   and   dy = f(r, c) - f(r + 1, c)
 * and returns mean(|dx| * |dy|).
 *
 * Inputs:
 * @param image: BGR (3-channel) or single-channel image
 * Return: double  mean product of absolute differences (higher = sharper)
 */
double smd2(cv::Mat &image)
{
	assert(!image.empty());

	cv::Mat gray_img;
	if (image.channels() == 3){
		cv::cvtColor(image, gray_img, cv::COLOR_BGR2GRAY);
	} else {
		// Already single-channel: use it directly instead of leaving gray_img empty.
		gray_img = image;
	}

	// filter2D computes a correlation, so these kernels give
	// dx = f(r,c) - f(r,c+1) and dy = f(r,c) - f(r+1,c).
	cv::Mat kernel_x(3, 3, CV_32F, cv::Scalar(0));
	kernel_x.at<float>(1, 1) = 1.0f;
	kernel_x.at<float>(1, 2) = -1.0f;
	cv::Mat kernel_y(3, 3, CV_32F, cv::Scalar(0));
	kernel_y.at<float>(1, 1) = 1.0f;
	kernel_y.at<float>(2, 1) = -1.0f;

	// Float output: with an 8-bit destination negative differences clamp to 0
	// and the product below saturates at 255.
	cv::Mat dx, dy;
	cv::filter2D(gray_img, dx, CV_32F, kernel_x);
	cv::filter2D(gray_img, dy, CV_32F, kernel_y);

	dx = cv::abs(dx);
	dy = cv::abs(dy);
	cv::Mat G;
	cv::multiply(dx, dy, G);

	return cv::mean(G)[0];
}

/**
 * Energy-of-gradient sharpness metric.
 *
 * With f(r, c) the gray image, computes
 *   dx = f(r, c) - f(r, c + 1)   and   dy = f(r, c) - f(r + 1, c)
 * and returns mean(dx^2 + dy^2).
 *
 * Inputs:
 * @param image: BGR (3-channel) or single-channel image
 * Return: double  mean squared first-order difference (higher = sharper)
 */
double energy_gradient(cv::Mat &image)
{
	assert(!image.empty());

	cv::Mat gray_img;
	if (image.channels() == 3){
		cv::cvtColor(image, gray_img, cv::COLOR_BGR2GRAY);
	} else {
		// Already single-channel: use it directly instead of leaving gray_img empty.
		gray_img = image;
	}

	// filter2D computes a correlation, so these kernels give
	// dx = f(r,c) - f(r,c+1) and dy = f(r,c) - f(r+1,c).
	cv::Mat kernel_x(3, 3, CV_32F, cv::Scalar(0));
	kernel_x.at<float>(1, 1) = 1.0f;
	kernel_x.at<float>(1, 2) = -1.0f;
	cv::Mat kernel_y(3, 3, CV_32F, cv::Scalar(0));
	kernel_y.at<float>(1, 1) = 1.0f;
	kernel_y.at<float>(2, 1) = -1.0f;

	// Float output: an 8-bit destination would drop negative differences and
	// saturate the squares at 255.
	cv::Mat dx, dy;
	cv::filter2D(gray_img, dx, CV_32F, kernel_x);
	cv::filter2D(gray_img, dy, CV_32F, kernel_y);

	cv::multiply(dx, dx, dx);
	cv::multiply(dy, dy, dy);
	cv::Mat G = dx + dy;

	return cv::mean(G)[0];
}

/**
 * EAV point-sharpness metric.
 *
 * For every interior pixel, sums the absolute differences to its 8
 * neighbours. The four diagonal neighbours are weighted by 0.7 (about
 * 1/sqrt(2)) and the four edge-adjacent neighbours by 1. The sum is divided
 * by the total pixel count.
 *
 * Inputs:
 * @param image: BGR (3-channel) or single-channel image; the grayscale data
 *               is read as 8-bit (uchar) values.
 * Return: double  weighted neighbour-difference sum per pixel (higher = sharper)
 */
double eav(cv::Mat &image)
{
	assert(!image.empty());

	cv::Mat gray_img;
	if (image.channels() == 3){
		cv::cvtColor(image, gray_img, cv::COLOR_BGR2GRAY);
	} else {
		// Already single-channel: use it directly instead of leaving gray_img empty.
		gray_img = image;
	}

	const double kDiagWeight = 0.7;
	double result = 0.0;
	for (int i = 1; i < gray_img.rows - 1; ++i){
		const uchar *prev = gray_img.ptr<uchar>(i - 1);
		const uchar *cur = gray_img.ptr<uchar>(i);
		const uchar *next = gray_img.ptr<uchar>(i + 1);
		// Skip the first and last column so that j - 1 and j + 1 stay inside the row.
		for (int j = 1; j < gray_img.cols - 1; ++j){
			const int c = cur[j];
			const int diagonal = std::abs(prev[j - 1] - c) + std::abs(prev[j + 1] - c) +
				std::abs(next[j - 1] - c) + std::abs(next[j + 1] - c);
			const int adjacent = std::abs(prev[j] - c) + std::abs(next[j] - c) +
				std::abs(cur[j - 1] - c) + std::abs(cur[j + 1] - c);
			result += kDiagWeight * diagonal + adjacent;
		}
	}

	return result / static_cast<double>(gray_img.total());
}

/**
 * Mean structural similarity (SSIM) between two images.
 *
 * Both inputs are converted to CV_32F. Local means, variances and covariance
 * are estimated with an 11x11 Gaussian window (sigma 1.5). The per-pixel SSIM
 * map is averaged over the image and then over its channels.
 *
 * Inputs:
 * @param i1: first image (non-empty)
 * @param i2: second image (non-empty); expected to match i1 in size and channel count
 * Return: double  mean SSIM (close to 1.0 for identical inputs)
 */
double ssim(cv::Mat &i1, cv::Mat & i2)
{
	assert(!i1.empty() && !i2.empty());

	// (K1*L)^2 and (K2*L)^2 with K1 = 0.01, K2 = 0.03, L = 255.
	const double C1 = 6.5025, C2 = 58.5225;
	const cv::Size window(11, 11);
	const double sigma = 1.5;

	cv::Mat I1, I2;
	i1.convertTo(I1, CV_32F);
	i2.convertTo(I2, CV_32F);
	cv::Mat I1_2 = I1.mul(I1);
	cv::Mat I2_2 = I2.mul(I2);
	cv::Mat I1_I2 = I1.mul(I2);

	// Local means.
	cv::Mat mu1, mu2;
	cv::GaussianBlur(I1, mu1, window, sigma);
	cv::GaussianBlur(I2, mu2, window, sigma);
	cv::Mat mu1_2 = mu1.mul(mu1);
	cv::Mat mu2_2 = mu2.mul(mu2);
	cv::Mat mu1_mu2 = mu1.mul(mu2);

	// Local variances and covariance: E[X^2] - E[X]^2 and E[XY] - E[X]E[Y].
	cv::Mat sigma1_2, sigma2_2, sigma12;
	cv::GaussianBlur(I1_2, sigma1_2, window, sigma);
	sigma1_2 -= mu1_2;
	cv::GaussianBlur(I2_2, sigma2_2, window, sigma);
	sigma2_2 -= mu2_2;
	cv::GaussianBlur(I1_I2, sigma12, window, sigma);
	sigma12 -= mu1_mu2;

	// numerator = (2*mu1*mu2 + C1) * (2*sigma12 + C2)
	cv::Mat t1 = 2 * mu1_mu2 + C1;
	cv::Mat t2 = 2 * sigma12 + C2;
	cv::Mat numerator = t1.mul(t2);

	// denominator = (mu1^2 + mu2^2 + C1) * (sigma1^2 + sigma2^2 + C2)
	t1 = mu1_2 + mu2_2 + C1;
	t2 = sigma1_2 + sigma2_2 + C2;
	cv::Mat denominator = t1.mul(t2);

	cv::Mat ssim_map;
	cv::divide(numerator, denominator, ssim_map);
	const cv::Scalar mssim = cv::mean(ssim_map);

	// Average over the channels actually present (a fixed /3 would scale a
	// single-channel result down to a third).
	const int channels = ssim_map.channels();
	double total = 0.0;
	for (int c = 0; c < channels; ++c){
		total += mssim.val[c];
	}
	return total / channels;
}
 
/**
 * NRSS (no-reference structural sharpness) metric.
 *
 * 1. Builds a reference image by Gaussian-blurring the input (7x7, sigma 6).
 * 2. Computes the Sobel gradient magnitude of the input (G) and of the
 *    reference (Gr).
 * 3. Scans an 8x8 grid of overlapping blocks. Each block is 2/9 of the image
 *    width/height, with a stride of half a block. It keeps the block whose
 *    gradients in G have the largest standard deviation.
 * 4. Returns 1 - SSIM(G_block, Gr_block).
 *
 * Note: only the single most-informative block is used, not an average over
 * several blocks.
 *
 * Inputs:
 * @param image: BGR (3-channel) or single-channel image
 * Return: double  1 - SSIM of the chosen gradient block (higher = sharper);
 *                 0 when no block has any gradient variation.
 */
double nrss(cv::Mat &image)
{
	assert(!image.empty());

	cv::Mat gray_img;
	if (image.channels() == 3){
		cv::cvtColor(image, gray_img, cv::COLOR_BGR2GRAY);
	} else {
		// Already single-channel: use it directly instead of leaving gray_img empty.
		gray_img = image;
	}

	// Reference image: a strongly low-pass filtered copy of the input.
	cv::Mat Ir;
	cv::GaussianBlur(gray_img, Ir, cv::Size(7, 7), 6, 6);

	// Sobel gradient magnitude of the input (G) and of the reference (Gr).
	// (Sobel with dx = dy = 1 would be the mixed second derivative, not a gradient.)
	cv::Mat gx, gy, G, Gr;
	cv::Sobel(gray_img, gx, CV_32F, 1, 0);
	cv::Sobel(gray_img, gy, CV_32F, 0, 1);
	cv::magnitude(gx, gy, G);
	cv::Sobel(Ir, gx, CV_32F, 1, 0);
	cv::Sobel(Ir, gy, CV_32F, 0, 1);
	cv::magnitude(gx, gy, Gr);

	// 8x8 grid of overlapping blocks: each block is 2/9 of the image
	// width/height and consecutive blocks start half a block apart.
	const int block_cols = G.cols * 2 / 9;
	const int block_rows = G.rows * 2 / 9;
	cv::Mat best_G, best_Gr;
	double max_stddev = 0.0;
	for (int i = 0; i < 64; ++i){
		const int left_x = (i % 8) * (block_cols / 2);
		const int left_y = (i / 8) * (block_rows / 2);
		// Clamp to the image size: a Rect may end exactly at cols/rows.
		const int right_x = std::min(left_x + block_cols, G.cols);
		const int right_y = std::min(left_y + block_rows, G.rows);
		const cv::Rect roi(left_x, left_y, right_x - left_x, right_y - left_y);
		if (roi.area() <= 0){
			continue;  // image too small for a non-empty block
		}

		cv::Scalar mean, stddev;
		cv::meanStdDev(G(roi), mean, stddev);
		if (stddev.val[0] > max_stddev){
			max_stddev = stddev.val[0];
			best_G = G(roi);
			best_Gr = Gr(roi);
		}
	}

	// A featureless image has no block with gradient variation; SSIM of two
	// identical flat blocks is 1, so NRSS is 0.
	if (best_G.empty()){
		return 0.0;
	}

	return 1.0 - ssim(best_G, best_Gr);
}

/**
 * Blur detection for the blur/noise no-reference quality score.
 *
 * With f(r, c) the image:
 *  - D_h = |f(r+1, c) - f(r-1, c)| (vertical central difference; the "h" name
 *    is kept from the existing code).
 *  - Candidate edges C_h = D_h where D_h > mean(D_h), 0 elsewhere.
 *  - Edge points are where C_h is a strict local maximum against the pixels
 *    directly above and below.
 *  - At each edge point the inverse blur is max(BR_h, BR_v), with
 *    BR = |f - A| / A. A_h averages f(r-1, c) and f(r+1, c); A_v averages
 *    f(r, c-1) and f(r, c+1).
 *  - An edge point is blurred when its inverse blur is below 0.1.
 *
 * Inputs:
 * @param src:        single-channel image (any depth; it is converted to float)
 * @param blur_mean:  output, mean inverse blur over blurred edge points (0 if none)
 * @param blur_ratio: output, blurred edge points / edge points (0 if no edges)
 * Return: void
 */
void comput_blur_IQA(cv::Mat &src, float &blur_mean, float &blur_ratio)
{
	blur_mean = 0.f;
	blur_ratio = 0.f;
	if (src.empty()){
		return;
	}

	// Float working copy so differences and ratios are not saturated or truncated.
	cv::Mat img;
	src.convertTo(img, CV_32F);

	// Vertical central difference and the two neighbour averages (filter2D correlates).
	cv::Mat kernel_h = cv::Mat::zeros(cv::Size(3, 3), CV_32FC1);
	kernel_h.at<float>(0, 1) = -1;
	kernel_h.at<float>(2, 1) = 1;
	cv::Mat avg_kernel_h = cv::Mat::zeros(cv::Size(3, 3), CV_32FC1);
	avg_kernel_h.at<float>(0, 1) = 0.5f;
	avg_kernel_h.at<float>(2, 1) = 0.5f;
	cv::Mat avg_kernel_v = cv::Mat::zeros(cv::Size(3, 3), CV_32FC1);
	avg_kernel_v.at<float>(1, 0) = 0.5f;
	avg_kernel_v.at<float>(1, 2) = 0.5f;

	cv::Mat D_h, A_h, A_v;
	cv::filter2D(img, D_h, CV_32F, kernel_h);
	D_h = cv::abs(D_h);
	cv::filter2D(img, A_h, CV_32F, avg_kernel_h);
	cv::filter2D(img, A_v, CV_32F, avg_kernel_v);

	// Candidate edge points: C_h = D_h where D_h exceeds its own mean, else 0.
	const double D_h_mean = cv::mean(D_h)[0];
	cv::Mat C_h = cv::Mat::zeros(D_h.size(), CV_32FC1);
	D_h.copyTo(C_h, D_h > D_h_mean);

	const float kBlurThreshold = 0.1f;
	const float kEps = 1e-6f;  // keeps BR finite where the neighbour average is 0
	int edge_cnt = 0;
	int blur_cnt = 0;
	double blur_sum = 0.0;
	for (int i = 1; i < C_h.rows - 1; ++i){
		const float *prev = C_h.ptr<float>(i - 1);
		const float *cur = C_h.ptr<float>(i);
		const float *next = C_h.ptr<float>(i + 1);
		const float *f = img.ptr<float>(i);
		const float *a_h = A_h.ptr<float>(i);
		const float *a_v = A_v.ptr<float>(i);
		for (int j = 0; j < C_h.cols; ++j){
			// Edge point: C_h(r,c) > C_h(r-1,c) and C_h(r,c) > C_h(r+1,c).
			if (!(cur[j] > prev[j] && cur[j] > next[j])){
				continue;
			}
			++edge_cnt;

			const float br_h = std::fabs(f[j] - a_h[j]) / (a_h[j] + kEps);
			const float br_v = std::fabs(f[j] - a_v[j]) / (a_v[j] + kEps);
			const float inv_blur = std::max(br_h, br_v);
			if (inv_blur < kBlurThreshold){
				++blur_cnt;
				blur_sum += inv_blur;
			}
		}
	}

	if (blur_cnt > 0){
		blur_mean = static_cast<float>(blur_sum / blur_cnt);
	}
	if (edge_cnt > 0){
		blur_ratio = static_cast<float>(blur_cnt) / static_cast<float>(edge_cnt);
	}
}
 
/**
 * Noise detection for the blur/noise no-reference quality score.
 *
 * The image is converted to float and mean-filtered (3x3) so that edges, not
 * pixel noise, drive the selection. On the filtered image:
 *  - D_h = |vertical central difference|, D_v = |horizontal central difference|.
 *  - Noise candidates are pixels where D_h < mean(D_h) and D_v < mean(D_v),
 *    i.e. pixels that are not on an edge. Each candidate takes the value
 *    max(D_h, D_v).
 *  - Noise points are candidates whose value exceeds the mean of the
 *    candidate map (that mean includes its zeros).
 *
 * Inputs:
 * @param gray_img:    single-channel image (not modified)
 * @param noise_mean:  output, mean value of the noise points (0 when there are none)
 * @param noise_ratio: output, noise points / total pixels
 * Return: void
 */
void compute_noise_IQA(cv::Mat &gray_img, float &noise_mean, float &noise_ratio)
{
	noise_mean = 0.f;
	noise_ratio = 0.f;
	if (gray_img.empty()){
		return;
	}

	// Mean filter (in float, to avoid rounding) to suppress noise before edge detection.
	cv::Mat float_img, blur_img;
	gray_img.convertTo(float_img, CV_32F);
	cv::blur(float_img, blur_img, cv::Size(3, 3));

	// Absolute central differences of the filtered image.
	cv::Mat kernel_h = cv::Mat::zeros(cv::Size(3, 3), CV_32FC1);
	kernel_h.at<float>(0, 1) = -1;
	kernel_h.at<float>(2, 1) = 1;
	cv::Mat kernel_v = cv::Mat::zeros(cv::Size(3, 3), CV_32FC1);
	kernel_v.at<float>(1, 0) = -1;
	kernel_v.at<float>(1, 2) = 1;
	cv::Mat D_h, D_v;
	cv::filter2D(blur_img, D_h, CV_32F, kernel_h);
	cv::filter2D(blur_img, D_v, CV_32F, kernel_v);
	D_h = cv::abs(D_h);
	D_v = cv::abs(D_v);

	const float D_h_mean = static_cast<float>(cv::mean(D_h)[0]);
	const float D_v_mean = static_cast<float>(cv::mean(D_v)[0]);

	// Noise candidates: below-average gradient in both directions.
	cv::Mat N_cand = cv::Mat::zeros(D_h.size(), CV_32FC1);
	for (int i = 0; i < N_cand.rows; ++i){
		const float *data_h = D_h.ptr<float>(i);
		const float *data_v = D_v.ptr<float>(i);
		float *data = N_cand.ptr<float>(i);
		for (int j = 0; j < N_cand.cols; ++j){
			if (data_v[j] < D_v_mean && data_h[j] < D_h_mean){
				data[j] = std::max(data_v[j], data_h[j]);
			}
		}
	}

	// Final noise points: candidates above the candidate-map mean.
	const double N_cand_mean = cv::mean(N_cand)[0];
	cv::Mat N = cv::Mat::zeros(N_cand.size(), CV_32FC1);
	N_cand.copyTo(N, N_cand > N_cand_mean);

	const double sum_noise = cv::sum(N)[0];
	const int sum_noise_cnt = cv::countNonZero(N);
	// The small epsilon keeps the mean finite when there are no noise points.
	noise_mean = static_cast<float>(sum_noise / (sum_noise_cnt + 0.0001));
	noise_ratio = static_cast<float>(sum_noise_cnt) / static_cast<float>(N.total());
}
 
/**
 * No-reference image quality score from blur and noise estimates.
 *
 * Converts the input to grayscale and runs blur detection (comput_blur_IQA)
 * and noise detection (compute_noise_IQA). Combines the results as
 *   1 - (blur_mean + 0.95*blur_ratio + 0.3*noise_mean + 0.75*noise_ratio).
 *
 * Inputs:
 * @param image: BGR (3-channel) or single-channel image
 * Return: double  quality score (higher = less blur / noise)
 */
double blur_noise_IQA(cv::Mat &image)
{
	assert(!image.empty());

	cv::Mat gray_img;
	if (image.channels() == 3){
		cv::cvtColor(image, gray_img, cv::COLOR_BGR2GRAY);
	} else {
		// Already single-channel: use its pixels rather than an uninitialised buffer.
		gray_img = image;
	}

	// 1. Blur detection.
	float blur_mean = 0.f, blur_ratio = 0.f;
	comput_blur_IQA(gray_img, blur_mean, blur_ratio);

	// 2. Noise detection.
	float noise_mean = 0.f, noise_ratio = 0.f;
	compute_noise_IQA(gray_img, noise_mean, noise_ratio);

	// 3. Weighted combination of blur and noise.
	const double result = 1.0 - (blur_mean + 0.95 * blur_ratio + 0.3 * noise_mean + 0.75 * noise_ratio);
	return result;
}

/**
 * Compute the quality/sharpness score selected by `type` for `image`.
 * Returns 0 for an unrecognised type.
 */
double qe::getQuality(cv::Mat &image, QualityEvaluationType type)
{
	double score = 0;
	switch (type)
	{
	case BRENNER:         score = brenner(image);         break;
	case TENENGARD:       score = tenengard(image);       break;
	case LAPLACIAN:       score = laplacian(image);       break;
	case SMD:             score = smd(image);             break;
	case SMD2:            score = smd2(image);            break;
	case ENERGY_GRADIENT: score = energy_gradient(image); break;
	case EAV:             score = eav(image);             break;
	case NRSS:            score = nrss(image);            break;
	case BLUR_NOISE_IQA:  score = blur_noise_IQA(image);  break;
	default:                                              break;
	}
	return score;
}

/**
 * Crop a region centred on the image.
 *
 * The crop is 2 * int(size * rate) pixels in each dimension and is placed at
 * the image centre. rate = 0.25 gives the central region with 1/4 of the
 * original area. rate is clamped to [0, 0.5] so the region never leaves the
 * image.
 *
 * @param inputs:  source image
 * @param outputs: receives the cropped region; it is a view that shares pixel
 *                 data with `inputs` (no copy is made)
 * @param rate:    half of the crop size, as a fraction of the image size
 */
void qe::cropImage(cv::Mat &inputs, cv::Mat &outputs, const float& rate)
{
	const float r = std::min(std::max(rate, 0.0f), 0.5f);
	const int crop_w = 2 * static_cast<int>(inputs.cols * r);
	const int crop_h = 2 * static_cast<int>(inputs.rows * r);

	// Centre the crop; the fixed origin (w, h) was only centred for rate == 0.25
	// and ran past the image for rate > 1/3.
	const cv::Rect roi((inputs.cols - crop_w) / 2, (inputs.rows - crop_h) / 2, crop_w, crop_h);

	outputs = inputs(roi);
}

/**
 * Write one line per image, "<img_id> <img_path> <score>", to the text file at `path`.
 * An existing file is overwritten. If the file cannot be opened, nothing is
 * written and no error is reported.
 *
 * @param img_info: records to write (not modified)
 * @param path:     output file path
 */
void qe::write_to_txt(std::vector<imageInfo> &img_info, const std::string &path)
{
	std::ofstream ofs(path);
	for (const auto &img : img_info)
	{
		// '\n' rather than std::endl: avoid flushing the stream on every line.
		ofs << img.img_id << " " << img.img_path << " " << img.score << '\n';
	}
	// ofs is flushed and closed by its destructor.
}